{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to stream tool calls\n",
    "\n",
    "When tools are called in a streaming context, \n",
    "[message chunks](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
    "will be populated with [tool call chunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.tool.ToolCallChunk.html#langchain_core.messages.tool.ToolCallChunk) \n",
    "objects in a list via the `.tool_call_chunks` attribute. A `ToolCallChunk` includes \n",
    "optional string fields for the tool `name`, `args`, and `id`, and includes an optional \n",
    "integer field `index` that can be used to join chunks together. Fields are optional \n",
    "because portions of a tool call may be streamed across different chunks (e.g., a chunk \n",
    "that includes a substring of the arguments may have null values for the tool name and id).\n",
    "\n",
    "Because message chunks inherit from their parent message class, an \n",
    "[AIMessageChunk](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessageChunk.html#langchain_core.messages.ai.AIMessageChunk) \n",
    "with tool call chunks will also include `.tool_calls` and `.invalid_tool_calls` fields. \n",
    "These fields are parsed best-effort from the message's tool call chunks.\n",
    "\n",
    "Note that not all providers currently support streaming for tool calls. Before we start, let's define our tools and our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Adds a and b.\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiplies a and b.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "tools = [add, multiply]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
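    "# Prompt for the key only if it is not already set in the environment.\n",
    "if \"OPENAI_API_KEY\" not in os.environ:\n",
    "    os.environ[\"OPENAI_API_KEY\"] = getpass()\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
    "# bind_tools attaches the tool schemas so the model can emit tool calls.\n",
    "llm_with_tools = llm.bind_tools(tools)"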
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's define our query and stream our output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "[{'name': 'multiply', 'args': '', 'id': 'call_3aQwTP9CYlFxwOvQZPHDu6wL', 'index': 0}]\n",
      "[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 0}]\n",
      "[{'name': None, 'args': ': 3, ', 'id': None, 'index': 0}]\n",
      "[{'name': None, 'args': '\"b\": 1', 'id': None, 'index': 0}]\n",
      "[{'name': None, 'args': '2}', 'id': None, 'index': 0}]\n",
      "[{'name': 'add', 'args': '', 'id': 'call_SQUoSsJz2p9Kx2x73GOgN1ja', 'index': 1}]\n",
      "[{'name': None, 'args': '{\"a\"', 'id': None, 'index': 1}]\n",
      "[{'name': None, 'args': ': 11,', 'id': None, 'index': 1}]\n",
      "[{'name': None, 'args': ' \"b\": ', 'id': None, 'index': 1}]\n",
      "[{'name': None, 'args': '49}', 'id': None, 'index': 1}]\n",
      "[]\n"
     ]
    }
   ],
   "source": [
    "query = \"What is 3 * 12? Also, what is 11 + 49?\"\n",
    "\n",
    "async for chunk in llm_with_tools.astream(query):\n",
    "    print(chunk.tool_call_chunks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that adding message chunks with `+` merges their corresponding tool call chunks: chunks that share an `index` have their string fields concatenated. This is the principle by which LangChain's various [tool output parsers](/docs/how_to/output_parser_structured) support streaming.\n",
    "\n",
    "For example, below we accumulate tool call chunks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "[{'name': 'multiply', 'args': '', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}]\n",
      "[{'name': 'multiply', 'args': '{\"a\"', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, ', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 1', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '{\"a\"', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11,', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": ', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n",
      "[{'name': 'multiply', 'args': '{\"a\": 3, \"b\": 12}', 'id': 'call_AkL3dVeCjjiqvjv8ckLxL3gP', 'index': 0}, {'name': 'add', 'args': '{\"a\": 11, \"b\": 49}', 'id': 'call_b4iMiB3chGNGqbt5SjqqD2Wh', 'index': 1}]\n"
     ]
    }
   ],
   "source": [
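    "first = True\n",
    "async for chunk in llm_with_tools.astream(query):\n",
    "    if first:\n",
    "        # Seed the accumulator with the first chunk.\n",
    "        gathered = chunk\n",
    "        first = False\n",
    "    else:\n",
    "        # Adding chunks merges tool call chunks that share an index.\n",
    "        gathered = gathered + chunk\n",
    "\n",
    "    print(gathered.tool_call_chunks)"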
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'str'>\n"
     ]
    }
   ],
   "source": [
    "print(type(gathered.tool_call_chunks[0][\"args\"]))"
   ]
  },
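  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the merge step concrete, here is a minimal pure-Python sketch of index-based merging. It is not LangChain's implementation: `ChunkSketch`, `merge_chunks`, and the sample ids `call_1` and `call_2` are hypothetical names used only for illustration, and the sketch assumes every chunk carries an `index`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, TypedDict\n",
    "\n",
    "\n",
    "class ChunkSketch(TypedDict):\n",
    "    \"\"\"Hypothetical stand-in mirroring the ToolCallChunk fields.\"\"\"\n",
    "\n",
    "    name: Optional[str]\n",
    "    args: Optional[str]\n",
    "    id: Optional[str]\n",
    "    index: Optional[int]\n",
    "\n",
    "\n",
    "def merge_chunks(chunks):\n",
    "    \"\"\"Join chunks that share an index by concatenating string fields.\"\"\"\n",
    "    merged = {}\n",
    "    for chunk in chunks:\n",
    "        key = chunk['index']\n",
    "        if key not in merged:\n",
    "            # First chunk seen for this index: copy it as the accumulator.\n",
    "            merged[key] = dict(chunk)\n",
    "            continue\n",
    "        target = merged[key]\n",
    "        for field in ('name', 'args', 'id'):\n",
    "            # None means this chunk carried no data for the field.\n",
    "            if chunk[field] is not None:\n",
    "                target[field] = (target[field] or '') + chunk[field]\n",
    "    # Dicts preserve insertion order, so calls come back in arrival order.\n",
    "    return list(merged.values())\n",
    "\n",
    "\n",
    "streamed = [\n",
    "    {'name': 'multiply', 'args': '', 'id': 'call_1', 'index': 0},\n",
    "    {'name': None, 'args': '{\"a\"', 'id': None, 'index': 0},\n",
    "    {'name': None, 'args': ': 3, ', 'id': None, 'index': 0},\n",
    "    {'name': None, 'args': '\"b\": 1', 'id': None, 'index': 0},\n",
    "    {'name': None, 'args': '2}', 'id': None, 'index': 0},\n",
    "    {'name': 'add', 'args': '', 'id': 'call_2', 'index': 1},\n",
    "    {'name': None, 'args': '{\"a\": 11, \"b\": 49}', 'id': None, 'index': 1},\n",
    "]\n",
    "merged_chunks = merge_chunks(streamed)\n",
    "print(merged_chunks)"
   ]
  },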
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we accumulate chunks the same way but print `.tool_calls`, which shows the arguments being partially parsed into a dict as they arrive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n",
      "[]\n",
      "[{'name': 'multiply', 'args': {}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 1}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}, {'name': 'add', 'args': {}, 'id': 'call_54Hx3DGjZitFlEjgMe1DYonh'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_54Hx3DGjZitFlEjgMe1DYonh'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}, {'name': 'add', 'args': {'a': 11}, 'id': 'call_54Hx3DGjZitFlEjgMe1DYonh'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_54Hx3DGjZitFlEjgMe1DYonh'}]\n",
      "[{'name': 'multiply', 'args': {'a': 3, 'b': 12}, 'id': 'call_4p0D4tHVXSiae9Mu0e8jlI1m'}, {'name': 'add', 'args': {'a': 11, 'b': 49}, 'id': 'call_54Hx3DGjZitFlEjgMe1DYonh'}]\n"
     ]
    }
   ],
   "source": [
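    "first = True\n",
    "async for chunk in llm_with_tools.astream(query):\n",
    "    if first:\n",
    "        # Seed the accumulator with the first chunk.\n",
    "        gathered = chunk\n",
    "        first = False\n",
    "    else:\n",
    "        # Each addition yields a new chunk with freshly parsed tool calls.\n",
    "        gathered = gathered + chunk\n",
    "\n",
    "    print(gathered.tool_calls)"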
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n"
     ]
    }
   ],
   "source": [
    "print(type(gathered.tool_calls[0][\"args\"]))"
   ]
  }
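  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The best-effort parsing of partial arguments can be illustrated with a small standard-library sketch. This is an assumption-laden approximation, not LangChain's parser: `parse_partial_args` is a hypothetical helper that only handles flat JSON objects, and it is cruder than the real behavior (for example, it returns `{}` for a fragment that ends right after a key and colon)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def parse_partial_args(fragment):\n",
    "    \"\"\"Best-effort parse of a possibly truncated flat JSON object.\"\"\"\n",
    "    # Drop trailing whitespace and a dangling comma before closing the object.\n",
    "    trimmed = fragment.rstrip().rstrip(',')\n",
    "    # Try the text as-is first, then with a closing brace appended.\n",
    "    for candidate in (fragment, trimmed + '}'):\n",
    "        try:\n",
    "            parsed = json.loads(candidate)\n",
    "        except json.JSONDecodeError:\n",
    "            continue\n",
    "        if isinstance(parsed, dict):\n",
    "            return parsed\n",
    "    # Nothing parseable yet: report empty arguments.\n",
    "    return {}\n",
    "\n",
    "\n",
    "fragments = ['', '{\"a\"', '{\"a\": 3, ', '{\"a\": 3, \"b\": 1', '{\"a\": 3, \"b\": 12}']\n",
    "for fragment in fragments:\n",
    "    print(repr(fragment), '->', parse_partial_args(fragment))"
   ]
  }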
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
